This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Use mm in subclass #128

Closed
wants to merge 13 commits into from

Conversation


@drisspg drisspg commented Oct 21, 2023

Summary

We use the dispatch mechanism to implement mm directly on the subclass.
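
As a rough illustration of the approach (a minimal sketch, not this PR's actual Float8Tensor), a wrapper subclass can intercept aten.mm in __torch_dispatch__ and decide how the matmul runs. The class name and the dequantize math below are illustrative assumptions:

import torch

aten = torch.ops.aten

class Float8TensorSketch(torch.Tensor):
    """Illustrative wrapper subclass, not the implementation in this PR."""

    @staticmethod
    def __new__(cls, data, scale, orig_dtype):
        # The wrapper advertises the original (higher-precision) dtype.
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=orig_dtype, device=data.device
        )

    def __init__(self, data, scale, orig_dtype):
        self._data = data            # fp8-encoded payload
        self._scale = scale          # scalar scale tensor
        self._orig_dtype = orig_dtype

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is aten.mm.default:
            a, b = args
            # Sketch: dequantize to the original precision and run a plain mm.
            # The PR instead dispatches to a float8 mm (emulated or real) and
            # returns the result in higher precision by default.
            a_hp = a._data.to(a._orig_dtype) / a._scale.to(a._orig_dtype)
            b_hp = b._data.to(b._orig_dtype) / b._scale.to(b._orig_dtype)
            return aten.mm.default(a_hp, b_hp)
        raise NotImplementedError(f"{func} is not handled in this sketch")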

TODO

  • Hook the amax_buffer onto the float8_tensor so that it can be filled under dispatch
  • Update emulate path

Note

Vasiliy has already started this here:
#28

Some things have changed since then, though: we now output in higher precision by default. However, I still need to replicate the amax_buffer filling here and store it on the float8_tensor passed in.
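
For reference, a minimal sketch of what "filling the amax buffer under dispatch" could look like, assuming the buffer is handed to the cast alongside the tensor (the function name and signature here are hypothetical):

import torch
from typing import Optional

def cast_to_float8_recording_amax(
    hp_tensor: torch.Tensor,
    scale: torch.Tensor,
    float8_dtype=torch.float8_e4m3fn,
    amax_buffer: Optional[torch.Tensor] = None,
):
    # Record the amax as a side effect of the cast, so the next iteration's
    # scale can be derived without a separate pass over the data.
    if amax_buffer is not None:
        amax_buffer.fill_(hp_tensor.abs().max())
    return (hp_tensor * scale).to(float8_dtype)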

Corresponding core changes to get as far as possible in compile with the aot_eager backend:
pytorch/pytorch#111735

Checking against fake_mode=<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7f4c13cd1bd0>
attr=_data
attr_fake_mode=<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7f4c13c34d00>
attr=_scale
attr_fake_mode=<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7f4c13c34d00>

Current Compile Progress

  • backend = "eager_only", full_graph = False: ✅
  • backend = "eager_only", full_graph = False: ❌
E       torch._dynamo.exc.Unsupported: call_function UserDefinedObjectVariable(to_float8) [TensorVariable(), TensorVariable(), ConstantVariable(dtype), TensorVariable()] {}
  • backend = "aot_eager", full_graph = False: ❌
  File "/home/drisspg/meta/pytorch/torch/_functorch/aot_autograd.py", line 4187, in convert
    assert all(getattr(x, attr).fake_mode is fake_mode for attr in attrs)
torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
AssertionError:

@facebook-github-bot facebook-github-bot added the CLA Signed label Oct 21, 2023
@drisspg drisspg force-pushed the use_mm_in_subclass branch 2 times, most recently from 146c6a1 to 5fb21a4 on October 21, 2023 02:20
@drisspg drisspg force-pushed the use_mm_in_subclass branch from 5fb21a4 to f8b8d89 on October 21, 2023 02:28
# return Float8Tensor(tensors["_data"], tensors["_scale"], metadatas[0])
return Float8Tensor(inner_tensors["_data"], metadata["_scale"], metadata["_orig_dtype"])
assert len(inner_tensors) == 2
return Float8Tensor(inner_tensors["_data"], inner_tensors["_scale"], metadata["_orig_dtype"])
drisspg (Contributor, Author)

@bdhirsh okay, I had a conjecture that the way we were using subclasses was not really right. With this PR, plus commenting out these two lines:

https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L2791-L2792

I think I am getting somewhere. I had to comment them out since "_data" is not the same shape as the scalar tensor "_scale".
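
For context, a hedged sketch of the flatten/unflatten pair that the diff above implies, with _scale treated as an inner tensor rather than metadata (a fragment of the Float8Tensor class, following the snippet in this PR rather than any particular PyTorch release):

    def __tensor_flatten__(self):
        # Both the fp8 payload and the scale are real tensors that tracing
        # needs to see; only the original dtype remains plain metadata.
        return ["_data", "_scale"], {"_orig_dtype": self._orig_dtype}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata):
        assert len(inner_tensors) == 2
        return Float8Tensor(
            inner_tensors["_data"], inner_tensors["_scale"], metadata["_orig_dtype"]
        )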

Contributor

Ah yep, this is exactly the dynamic shapes issue that I hit with Vasiliy, so it's probably hi-pri to fix.

Right now there's a constraint with dynamic shapes and wrapper subclasses: we expect the inner tensors to have the same dynamic-ness as the actual wrapper's shape. Instead, dynamo should track the dynamic-ness of each inner tensor individually, and in the FP8 case it should recognize that the scale is static and doesn't change sizes.
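
Concretely (illustrative shapes, not taken from the PR): the wrapper's shape follows _data, which may be dynamic across iterations, while _scale is always a fixed 0-dim tensor, so marking every inner tensor dynamic is too coarse:

import torch

batch = 16  # may vary between iterations -> dynamic dimension
# _data follows the wrapper's (possibly dynamic) shape ...
data = torch.randn(batch, 4096).to(torch.float8_e4m3fn)
# ... while _scale never changes size; it is a 0-dim tensor.
scale = torch.tensor(1.0)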

@albanD albanD left a comment

pretty cool!

@drisspg drisspg force-pushed the use_mm_in_subclass branch from 47b2bf6 to 4459b89 on October 23, 2023 19:36

drisspg commented Oct 23, 2023

@zou3519 @bdhirsh Do you think this PR gets the tensor subclass version closer to working? If so, I'd like to merge it (I like these changes anyway, since we no longer have to track the output amax).

@drisspg drisspg force-pushed the use_mm_in_subclass branch 2 times, most recently from 5b15bb2 to 9b3ee33 on October 25, 2023 20:28
@drisspg drisspg force-pushed the use_mm_in_subclass branch from 2e94b6e to c915267 on October 27, 2023 23:34
@drisspg drisspg force-pushed the use_mm_in_subclass branch from 61791d8 to 007e00d on October 31, 2023 17:17
@drisspg drisspg requested review from albanD, bdhirsh and vkuzo October 31, 2023 17:34

    def to_original_precision(self):
        return FromFloat8ConstrFunc.apply(self)

    @staticmethod
    @torch._dynamo.allow_in_graph
drisspg (Contributor, Author)

@zou3519 without allow_in_graph this breaks

Contributor

I think the long-term correct thing here is also:

(1) to_float() remains a (static?) method on the subclass

(2) dynamo eventually understands that it should proxy custom methods on "traceable subclasses" directly into the torch graph

That way dynamo won't try to inline into to_float8() here and try to carefully understand the custom autograd.Function you wrote.
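
For reference, a hedged sketch of the allow_in_graph workaround under discussion: the decorator tells dynamo to record the decorated call as a single graph node instead of inlining and tracing through it (the function body here is illustrative, not the PR's to_float8):

import torch

@torch._dynamo.allow_in_graph
def to_float8_sketch(hp_tensor, scale, float8_dtype=torch.float8_e4m3fn):
    # Dynamo puts this call into the graph as one node rather than tracing
    # through the custom autograd.Function that would normally do the cast.
    return (hp_tensor * scale).to(float8_dtype)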

@@ -5,6 +5,7 @@ set -e

pytest test/test_base.py
pytest test/test_sam.py
pytest test/test_compile.py
@drisspg drisspg (Contributor, Author) Oct 31, 2023

I absolutely hate that internal keeps changing the GD file execution permissions on the shell scripts

y = torch.matmul(x_fp8, w_fp8.t())

# Cast gradY to float8_e5m2 during backward
y = self.cast_y_to_float8_in_bw(y, self.emulate)
Contributor

Mentioned offline, but food for thought worth noting here: it would be interesting to think about what it would take to have aten.matmul(Float8Tensor, Float8Tensor) actually return another Float8Tensor, and then leave it to the subclass to upcast on future ops that don't want to handle Float8 directly (a rough sketch of that policy follows the list below).

My understanding was:

(1) This is a pain mostly because the extra buffers for float8 live directly on the Float8Linear nn module today and not the subclass (probably for good reason)

(2) Doing this would provide benefit if we want to start increasing the number of ops that directly handle float8, but if all we care about is linear, then this generality is probably not very useful.
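
A rough sketch of that policy, under the assumption that only a small allow-list of ops stays in float8 and everything else falls back by upcasting (the helper below is hypothetical, not part of this PR):

import torch
from torch.utils._pytree import tree_map

# Ops the subclass chooses to handle natively in float8 (illustrative set).
FLOAT8_AWARE_OPS = {torch.ops.aten.mm.default, torch.ops.aten.t.default}

def float8_dispatch_policy(cls, func, args, kwargs):
    """Hypothetical __torch_dispatch__ body for a Float8Tensor-like subclass."""
    kwargs = kwargs or {}
    if func in FLOAT8_AWARE_OPS:
        # Stay in float8: run on the fp8 payloads and re-wrap the result as
        # another Float8Tensor. This needs an output scale, which is where
        # the extra buffers mentioned in (1) come into play.
        raise NotImplementedError("float8-native path left as a sketch")
    # Every other op: upcast to the original precision and fall back to the
    # regular kernels.
    unwrap = lambda t: t.to_original_precision() if isinstance(t, cls) else t
    return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))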

    def add_weight_tag(self):
        # We add a tag to the weight nn.Parameter in order to signal
        # to FSDP that this param is a weight
        self.weight._is_fp8_weight = True
Contributor

I'm reviewing the subclass changes but probably not the right person to review this one

drisspg (Contributor, Author)

This was added in a previous PR; I just moved it to the Mixin so that it can be added to the TP stuff.

    output_dtype = a._orig_dtype
    if a._emulate:
        assert a._emulate == b._emulate
        return torch.ops.aten.mm_float8_emulated(
Contributor

Oh also just thinking - should emulate just be a global config somewhere, instead of a flag that you have to plumb around?

drisspg (Contributor, Author)

Talked about this with Brian offline. This is probably right, but I am going to do it in a follow-up. I also want to see if we can get plain torch.nn.functional.linear working in LinearFloat8, and will do some matmul changes.
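
A minimal sketch of what that global config could look like (hypothetical; this PR keeps plumbing the emulate flag for now):

from dataclasses import dataclass

@dataclass
class Float8GlobalConfig:
    # When True, matmuls take the emulated float8 path instead of the
    # hardware scaled-mm kernel; set once instead of plumbing a flag around.
    emulate: bool = False

config = Float8GlobalConfig()

# Usage inside the mm override would then be roughly:
#   if config.emulate:
#       return mm_float8_emulated(...)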

@bdhirsh bdhirsh (Contributor) left a comment

the subclass changes lgtm! (I didn't fully read the non-subclass changes)

@facebook-github-bot (Contributor)

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@drisspg merged this pull request in b518e2a.

Labels: CLA Signed, Merged
4 participants